{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.init as init\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "from six.moves import cPickle as pickle\n",
    "from six.moves import range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train set (97722, 1, 64, 64) (97722, 6)\n",
      "40006\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_1x64x64_train.pickle'\n",
    "\n",
    "with open(pickle_file, 'rb') as f:\n",
    "    save = pickle.load(f)\n",
    "    train_dataset = save['train_dataset']\n",
    "    train_labels = save['train_labels']\n",
    "    del save  # hint to help gc free up memory\n",
    "    print('train set', train_dataset.shape, train_labels.shape)\n",
    "\n",
    "c3 = []\n",
    "\n",
    "for i in range(97722):\n",
    "    categ = train_labels[i][0]\n",
    "    if(categ == 2):\n",
    "        c3.append(i)\n",
    "\n",
    "print(len(c3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(40006, 1, 64, 64)\n",
      "(40006, 6)\n"
     ]
    }
   ],
   "source": [
    "c3_data = train_dataset[c3]\n",
    "c3_target = train_labels[c3]\n",
    "print(c3_data.shape)\n",
    "print(c3_target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 3, padding=(1, 1))\n",
    "        self.conv2 = nn.Conv2d(20, 40, 3, padding=(1, 1))\n",
    "        self.conv3 = nn.Conv2d(40, 80, 3, padding=(1, 1))\n",
    "        self.conv4 = nn.Conv2d(80, 120, 3, padding=(1, 1))\n",
    "        self.conv5 = nn.Conv2d(120, 160, 3, padding=(1, 1))\n",
    "        self.conv6 = nn.Conv2d(160, 200, 3, padding=(1, 1))\n",
    "        self.conv7 = nn.Conv2d(200, 240, 3, padding=(1, 1))\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.FC = nn.Linear(960, 1080)\n",
    "        self.digitlength = nn.Linear(1080, 7)\n",
    "        self.digit1 = nn.Linear(1080, 10)\n",
    "        self.digit2 = nn.Linear(1080, 10)\n",
    "        self.digit3 = nn.Linear(1080, 10)\n",
    "        self.digit4 = nn.Linear(1080, 10)\n",
    "        self.digit5 = nn.Linear(1080, 10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = self.pool(F.relu(self.conv4(x)))\n",
    "        x = self.pool(F.relu(self.conv5(x)))\n",
    "        x = self.pool(F.relu(self.conv6(x)))\n",
    "        x = self.pool(F.relu(self.conv7(x)))\n",
    "        x = x.view(-1, 960)\n",
    "        x = self.FC(x)\n",
    "        yl = self.digitlength(x)\n",
    "        y1 = self.digit1(x)\n",
    "        y2 = self.digit2(x)\n",
    "        y3 = self.digit3(x)\n",
    "        y4 = self.digit4(x)\n",
    "        y5 = self.digit5(x)\n",
    "        return [yl, y1, y2, y3, y4, y5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net (\n",
       "  (conv1): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(20, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(40, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv4): Conv2d(80, 120, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv5): Conv2d(120, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv6): Conv2d(160, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv7): Conv2d(200, 240, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (pool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
       "  (FC): Linear (960 -> 1080)\n",
       "  (digitlength): Linear (1080 -> 7)\n",
       "  (digit1): Linear (1080 -> 10)\n",
       "  (digit2): Linear (1080 -> 10)\n",
       "  (digit3): Linear (1080 -> 10)\n",
       "  (digit4): Linear (1080 -> 10)\n",
       "  (digit5): Linear (1080 -> 10)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = Net()\n",
    "f = open('c3_leg_ep_20.pkl', 'rb')\n",
    "net.load_state_dict(torch.load(f))\n",
    "f.close()\n",
    "net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for param in net.parameters():\n",
    "    if(param.grad is not None):\n",
    "        print(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.FloatTensor torch.Size([40006, 1, 64, 64])\n",
      "torch.LongTensor torch.Size([40006, 6])\n"
     ]
    }
   ],
   "source": [
    "c3_data_tensor = torch.from_numpy(c3_data)\n",
    "c3_target_tensor = torch.from_numpy(c3_target).type(torch.LongTensor)\n",
    "print(c3_data_tensor.type(), c3_data_tensor.size())\n",
    "print(c3_target_tensor.type(), c3_target_tensor.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "objective = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(net.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "625\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 20\n",
    "batch_size = 64\n",
    "num_train =  c3_data.shape[0]\n",
    "iter_per_epoch = num_train // batch_size\n",
    "print_every = 300\n",
    "print(iter_per_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "epoch_losses = {i:[] for i in range(num_epochs)}\n",
    "loss_history = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/19\n",
      "--------------------------------------------------------------------------------------------------------------\n",
      "Iteration :  1  /  625\n",
      "loss :  93.4910888671875\n",
      "lossl :  84.45161437988281 loss1 :  0.5001339316368103 loss2 :  8.539336204528809\n",
      "Iteration :  301  /  625\n",
      "loss :  0.7808812856674194\n",
      "lossl :  0.000708363950252533 loss1 :  0.35415735840797424 loss2 :  0.42601555585861206\n",
      "Iteration :  601  /  625\n",
      "loss :  0.540520966053009\n",
      "lossl :  0.0006491988897323608 loss1 :  0.22309759259223938 loss2 :  0.3167741894721985\n",
      "time taken :  979.2570998668671\n",
      "--------------------------------------------------------------------------------------------------------------\n",
      "Epoch 1/19\n",
      "--------------------------------------------------------------------------------------------------------------\n",
      "Iteration :  1  /  625\n",
      "loss :  0.14915671944618225\n",
      "lossl :  0.00018310546875 loss1 :  0.07522555440664291 loss2 :  0.07374805212020874\n",
      "Iteration :  301  /  625\n",
      "loss :  0.34971943497657776\n",
      "lossl :  0.00036091357469558716 loss1 :  0.12418654561042786 loss2 :  0.2251719832420349\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-f61001f69b59>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     31\u001b[0m         \u001b[0mfinal_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m         \u001b[0mloss_history\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfinal_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sreekar/miniconda3/envs/deep/lib/python3.5/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m     43\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m                     \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m                 \u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     46\u001b[0m                 \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):  # loop over the dataset multiple times\n",
    "    \n",
    "    print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "    print('-' * 110)\n",
    "\n",
    "    i = 0\n",
    "    rng_state = torch.get_rng_state()\n",
    "    new_idxs = torch.randperm(num_train)\n",
    "    c3_X = c3_data_tensor[new_idxs]\n",
    "    c3_Y = c3_target_tensor[new_idxs]\n",
    "    \n",
    "    t1 = time.time()\n",
    "    for t in range(iter_per_epoch):\n",
    "        \n",
    "        X_batch = c3_X[i: i+batch_size]\n",
    "        Y_batch = c3_Y[i: i+batch_size][:,0:4]\n",
    "        i += batch_size\n",
    "\n",
    "        Y_batch = Variable(Y_batch.cuda())\n",
    "        X_batch = Variable(X_batch.cuda())\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        outputs = net(X_batch)\n",
    "        \n",
    "        lossl = objective(outputs[0], Y_batch[:, 0])\n",
    "        loss1 = objective(outputs[1], Y_batch[:, 1])\n",
    "        loss2 = objective(outputs[2], Y_batch[:, 2])\n",
    "        final_loss = lossl + loss1 + loss2\n",
    "        \n",
    "        final_loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        loss_history.append(final_loss.data[0])\n",
    "        epoch_losses[epoch].append(final_loss.data[0])\n",
    "        \n",
    "        if (t % print_every == 0):\n",
    "            print('Iteration : ', t+1, ' / ', iter_per_epoch)\n",
    "            print('loss : ', final_loss.data[0])\n",
    "            print('lossl : ', lossl.data[0], 'loss1 : ', loss1.data[0], 'loss2 : ', loss2.data[0])\n",
    "    t2 = time.time()\n",
    "    print(\"time taken : \", t2-t1)\n",
    "    print('-' * 110)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = open(\"c3-c2_ep_2.pkl\", \"bw\")\n",
    "torch.save(net.state_dict(), f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fc79b82d5c0>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sreekar/miniconda3/envs/deep/lib/python3.5/site-packages/matplotlib/font_manager.py:280: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      "  'Matplotlib is building the font cache using fc-list. '\n"
     ]
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.plot(loss_history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
